{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:43:39.495862Z",
     "iopub.status.busy": "2021-08-12T07:43:39.495311Z",
     "iopub.status.idle": "2021-08-12T07:43:39.868620Z",
     "shell.execute_reply": "2021-08-12T07:43:39.868122Z",
     "shell.execute_reply.started": "2021-08-12T07:43:39.495814Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import cv2\n",
    "import glob\n",
    "\n",
    "%pylab inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:43:40.913269Z",
     "iopub.status.busy": "2021-08-12T07:43:40.912717Z",
     "iopub.status.idle": "2021-08-12T07:43:41.410698Z",
     "shell.execute_reply": "2021-08-12T07:43:41.410154Z",
     "shell.execute_reply.started": "2021-08-12T07:43:40.913221Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os, sys, codecs, glob\n",
    "from PIL import Image, ImageDraw\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import cv2\n",
    "\n",
    "import torch\n",
    "torch.backends.cudnn.benchmark = False\n",
    "# torch.backends.cudnn.enabled = False\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:43:41.842011Z",
     "iopub.status.busy": "2021-08-12T07:43:41.841549Z",
     "iopub.status.idle": "2021-08-12T07:43:41.847822Z",
     "shell.execute_reply": "2021-08-12T07:43:41.847019Z",
     "shell.execute_reply.started": "2021-08-12T07:43:41.841972Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiDataset(Dataset):\n",
    "    def __init__(self, img_path, img_group, transform):\n",
    "        self.img_path = img_path\n",
    "        self.transform = transform\n",
    "        self.group = img_group\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img = Image.open(self.img_path[index]).convert('RGB')\n",
    "        \n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        \n",
    "        return img, self.group[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:43:43.693140Z",
     "iopub.status.busy": "2021-08-12T07:43:43.692686Z",
     "iopub.status.idle": "2021-08-12T07:43:43.727664Z",
     "shell.execute_reply": "2021-08-12T07:43:43.727068Z",
     "shell.execute_reply.started": "2021-08-12T07:43:43.693101Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>label</th>\n",
       "      <th>path</th>\n",
       "      <th>group</th>\n",
       "      <th>fold</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>008233.jpg</td>\n",
       "      <td>008233.jpg 006688.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/008233.jpg</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>006688.jpg</td>\n",
       "      <td>008233.jpg 006688.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/006688.jpg</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>000232.jpg</td>\n",
       "      <td>000232.jpg 003552.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/000232.jpg</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>003552.jpg</td>\n",
       "      <td>000232.jpg 003552.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/003552.jpg</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>000814.jpg</td>\n",
       "      <td>000814.jpg 013765.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/000814.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>013765.jpg</td>\n",
       "      <td>000814.jpg 013765.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/013765.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>001429.jpg</td>\n",
       "      <td>001429.jpg 014834.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/001429.jpg</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>014834.jpg</td>\n",
       "      <td>001429.jpg 014834.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/014834.jpg</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>012795.jpg</td>\n",
       "      <td>012795.jpg 015860.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/012795.jpg</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>015860.jpg</td>\n",
       "      <td>012795.jpg 015860.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/015860.jpg</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         name                  label                           path  group  \\\n",
       "0  008233.jpg  008233.jpg 006688.jpg  ./电商图像检索_数据集/train/008233.jpg      0   \n",
       "1  006688.jpg  008233.jpg 006688.jpg  ./电商图像检索_数据集/train/006688.jpg      0   \n",
       "2  000232.jpg  000232.jpg 003552.jpg  ./电商图像检索_数据集/train/000232.jpg      1   \n",
       "3  003552.jpg  000232.jpg 003552.jpg  ./电商图像检索_数据集/train/003552.jpg      1   \n",
       "4  000814.jpg  000814.jpg 013765.jpg  ./电商图像检索_数据集/train/000814.jpg      2   \n",
       "5  013765.jpg  000814.jpg 013765.jpg  ./电商图像检索_数据集/train/013765.jpg      2   \n",
       "6  001429.jpg  001429.jpg 014834.jpg  ./电商图像检索_数据集/train/001429.jpg      3   \n",
       "7  014834.jpg  001429.jpg 014834.jpg  ./电商图像检索_数据集/train/014834.jpg      3   \n",
       "8  012795.jpg  012795.jpg 015860.jpg  ./电商图像检索_数据集/train/012795.jpg      4   \n",
       "9  015860.jpg  012795.jpg 015860.jpg  ./电商图像检索_数据集/train/015860.jpg      4   \n",
       "\n",
       "   fold  \n",
       "0     0  \n",
       "1     0  \n",
       "2     1  \n",
       "3     1  \n",
       "4     2  \n",
       "5     2  \n",
       "6     3  \n",
       "7     3  \n",
       "8     4  \n",
       "9     4  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df = pd.read_csv('./电商图像检索_数据集/train.csv')\n",
    "train_df['path'] = './电商图像检索_数据集/train/' + train_df['name']\n",
    "\n",
    "train_df['group'] = pd.factorize(train_df['label'])[0]\n",
    "train_df['fold'] = train_df['group'] % 5\n",
    "\n",
    "train_df = train_df.sort_values(by='group')\n",
    "train_df.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:43:46.841825Z",
     "iopub.status.busy": "2021-08-12T07:43:46.841400Z",
     "iopub.status.idle": "2021-08-12T07:43:46.852833Z",
     "shell.execute_reply": "2021-08-12T07:43:46.851866Z",
     "shell.execute_reply.started": "2021-08-12T07:43:46.841791Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "tr_path = train_df[train_df['fold'] != 0]['path'].values\n",
    "tr_label = train_df[train_df['fold'] != 0]['group']\n",
    "tr_label = pd.factorize(tr_label)[0]\n",
    "\n",
    "val_path = train_df[train_df['fold'] == 0]['path'].values\n",
    "val_label = train_df[train_df['fold'] == 0]['group'].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:43:48.772175Z",
     "iopub.status.busy": "2021-08-12T07:43:48.771658Z",
     "iopub.status.idle": "2021-08-12T07:43:48.778807Z",
     "shell.execute_reply": "2021-08-12T07:43:48.778040Z",
     "shell.execute_reply.started": "2021-08-12T07:43:48.772134Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1805"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tr_label.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:44:50.686192Z",
     "iopub.status.busy": "2021-08-12T07:44:50.685621Z",
     "iopub.status.idle": "2021-08-12T07:44:50.697021Z",
     "shell.execute_reply": "2021-08-12T07:44:50.695847Z",
     "shell.execute_reply.started": "2021-08-12T07:44:50.686144Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(tr_path, tr_label,\n",
    "                        transforms.Compose([\n",
    "                        transforms.Resize((300, 300)),\n",
    "                        transforms.RandomHorizontalFlip(),\n",
    "                        transforms.RandomVerticalFlip(),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ),\n",
    "    batch_size=10, shuffle=True, num_workers=5,\n",
    ")\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(val_path, val_label,\n",
    "                        transforms.Compose([\n",
    "                        transforms.Resize((300, 300)),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ),\n",
    "    batch_size=10, shuffle=False, num_workers=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:44:51.821872Z",
     "iopub.status.busy": "2021-08-12T07:44:51.821321Z",
     "iopub.status.idle": "2021-08-12T07:44:51.834648Z",
     "shell.execute_reply": "2021-08-12T07:44:51.834004Z",
     "shell.execute_reply.started": "2021-08-12T07:44:51.821825Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class ArcModule(nn.Module):\n",
    "    def __init__(self, in_features, out_features, s = 10, m = 0.1):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.s = s\n",
    "        self.m = m\n",
    "        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))\n",
    "        nn.init.xavier_normal_(self.weight)\n",
    "\n",
    "        self.cos_m = math.cos(m)\n",
    "        self.sin_m = math.sin(m)\n",
    "        self.th = torch.tensor(math.cos(math.pi - m))\n",
    "        self.mm = torch.tensor(math.sin(math.pi - m) * m)\n",
    "\n",
    "    def forward(self, inputs, labels):\n",
    "        cos_th = F.linear(inputs, F.normalize(self.weight))\n",
    "        cos_th = cos_th.clamp(-1, 1)\n",
    "        sin_th = torch.sqrt(1.0 - torch.pow(cos_th, 2))\n",
    "        cos_th_m = cos_th * self.cos_m - sin_th * self.sin_m\n",
    "        # print(type(cos_th), type(self.th), type(cos_th_m), type(self.mm))\n",
    "        cos_th_m = torch.where(cos_th > self.th, cos_th_m, cos_th - self.mm)\n",
    "        \n",
    "        cond_v = cos_th - self.th\n",
    "        cond = cond_v <= 0\n",
    "        cos_th_m[cond] = (cos_th - self.mm)[cond]\n",
    "\n",
    "        if labels.dim() == 1:\n",
    "            labels = labels.unsqueeze(-1)\n",
    "        onehot = torch.zeros(cos_th.size()).cuda()\n",
    "        labels = labels.type(torch.LongTensor).cuda()\n",
    "        onehot.scatter_(1, labels, 1.0)\n",
    "        outputs = onehot * cos_th_m + (1.0 - onehot) * cos_th\n",
    "        outputs = outputs * self.s\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:44:55.076500Z",
     "iopub.status.busy": "2021-08-12T07:44:55.075902Z",
     "iopub.status.idle": "2021-08-12T07:44:55.735223Z",
     "shell.execute_reply": "2021-08-12T07:44:55.734566Z",
     "shell.execute_reply.started": "2021-08-12T07:44:55.076452Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Linear(in_features=1536, out_features=137, bias=True)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import timm\n",
    "timm.create_model('efficientnet_b3', num_classes=137, \n",
    "                          pretrained=True).classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:45:02.346374Z",
     "iopub.status.busy": "2021-08-12T07:45:02.345796Z",
     "iopub.status.idle": "2021-08-12T07:45:04.339930Z",
     "shell.execute_reply": "2021-08-12T07:45:04.339046Z",
     "shell.execute_reply.started": "2021-08-12T07:45:02.346327Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import timm\n",
    "\n",
    "class XunFeiNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(XunFeiNet, self).__init__()\n",
    "                \n",
    "        model = timm.create_model('efficientnet_b3', num_classes=137, \n",
    "                          pretrained=True)\n",
    "        model.classifier = torch.nn.Identity()\n",
    "        self.model = model\n",
    "        self.margin = ArcModule(in_features=1536, \n",
    "                                out_features = 1806)\n",
    "        \n",
    "    def forward(self, img, labels=None):        \n",
    "        feat = self.model(img)\n",
    "        \n",
    "        feat = F.normalize(feat)\n",
    "        if labels is not None:\n",
    "            return self.margin(feat, labels)\n",
    "        return feat\n",
    "    \n",
    "model = XunFeiNet().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:45:04.780925Z",
     "iopub.status.busy": "2021-08-12T07:45:04.780362Z",
     "iopub.status.idle": "2021-08-12T07:45:04.791315Z",
     "shell.execute_reply": "2021-08-12T07:45:04.790634Z",
     "shell.execute_reply.started": "2021-08-12T07:45:04.780876Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train(train_loader, model, criterion, optimizer, epoch):\n",
    "    model.train()\n",
    "\n",
    "    for i, (input, target) in enumerate(train_loader):\n",
    "        input = input.cuda(non_blocking=True)\n",
    "        target = target.cuda(non_blocking=True)\n",
    "\n",
    "        output = model(input, target)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if i % 40 == 0:\n",
    "            print(loss.item())\n",
    "            \n",
    "def validate(val_loader, model):\n",
    "    model.eval()\n",
    "    \n",
    "    val_feats = []\n",
    "    with torch.no_grad():\n",
    "        end = time.time()\n",
    "        for i, (input, target) in enumerate(val_loader):\n",
    "            input = input.cuda()\n",
    "            target = target.cuda()\n",
    "\n",
    "            # compute output\n",
    "            output = model(input)\n",
    "            val_feats.append(output.data.cpu().numpy())\n",
    "    return val_feats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:45:06.115811Z",
     "iopub.status.busy": "2021-08-12T07:45:06.115382Z",
     "iopub.status.idle": "2021-08-12T07:45:06.121782Z",
     "shell.execute_reply": "2021-08-12T07:45:06.120796Z",
     "shell.execute_reply.started": "2021-08-12T07:45:06.115773Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def set_iou(label, predict):\n",
    "    interset = set(label.split()) &  set(predict.split())\n",
    "    unionset = set(label.split()) | set(predict.split())\n",
    "    return len(interset) *1.0 / len(unionset) *1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:45:07.180552Z",
     "iopub.status.busy": "2021-08-12T07:45:07.179994Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  0\n",
      "8.410222053527832\n",
      "8.738941192626953\n",
      "8.133112907409668\n",
      "8.43204402923584\n",
      "8.73010540008545\n",
      "8.435385704040527\n",
      "8.359577178955078\n",
      "8.50145149230957\n",
      "7.699382781982422\n",
      "7.894308567047119\n",
      "7.7224531173706055\n",
      "6.893548011779785\n",
      "7.043992042541504\n",
      "7.310341835021973\n",
      "Val 0.9126315789473685 0.4776602828153394\n",
      "Epoch:  1\n",
      "7.161893367767334\n",
      "7.050235748291016\n",
      "7.080412864685059\n",
      "6.716268062591553\n",
      "6.632183074951172\n",
      "6.919695854187012\n",
      "6.6141815185546875\n",
      "7.120757102966309\n",
      "6.330360412597656\n",
      "6.984335899353027\n",
      "6.274984359741211\n",
      "6.637211799621582\n",
      "5.890048027038574\n",
      "5.753139972686768\n",
      "Val 0.8352631578947368 0.544906177740984\n",
      "Epoch:  2\n",
      "4.990122318267822\n",
      "5.206655502319336\n",
      "5.629385471343994\n",
      "5.0866241455078125\n",
      "5.858253479003906\n",
      "5.120569229125977\n",
      "4.58829402923584\n",
      "4.810513973236084\n",
      "4.749154090881348\n",
      "4.858401298522949\n",
      "4.933943748474121\n",
      "3.8157520294189453\n",
      "5.059332370758057\n",
      "5.189235687255859\n",
      "Val 0.7063157894736842 0.5878100637508792\n",
      "Epoch:  3\n",
      "2.842712879180908\n",
      "4.265428066253662\n",
      "3.71002459526062\n",
      "4.222585201263428\n",
      "4.116498947143555\n",
      "2.883312940597534\n",
      "4.075376987457275\n",
      "4.138993740081787\n",
      "3.350064754486084\n",
      "3.195556163787842\n",
      "3.3333733081817627\n",
      "3.740750789642334\n",
      "2.989471912384033\n",
      "2.817662477493286\n",
      "Val 0.7321052631578947 0.6047961329402526\n",
      "Epoch:  4\n",
      "2.6874446868896484\n",
      "2.84551739692688\n",
      "3.3532352447509766\n",
      "2.9507999420166016\n",
      "3.2798855304718018\n",
      "2.8438029289245605\n",
      "2.444345474243164\n",
      "3.2179367542266846\n",
      "3.0540945529937744\n",
      "3.150499105453491\n",
      "2.000516414642334\n",
      "2.669219493865967\n",
      "2.115185260772705\n",
      "2.4844212532043457\n",
      "Val 0.6289473684210526 0.6251751995288805\n",
      "Epoch:  5\n",
      "1.9250915050506592\n",
      "1.9584401845932007\n",
      "1.6891781091690063\n",
      "1.828899621963501\n"
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "criterion = nn.CrossEntropyLoss().cuda()\n",
    "optimizer = torch.optim.Adam(model.parameters(), 0.0003)\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.85)\n",
    "best_acc = 0.0\n",
    "\n",
    "for epoch in range(20):\n",
    "    print('Epoch: ', epoch)\n",
    "\n",
    "    train(train_loader, model, criterion, optimizer, epoch)\n",
    "    scheduler.step()\n",
    "    \n",
    "    val_feats = validate(val_loader, model)\n",
    "    val_feats = np.vstack(val_feats)\n",
    "    # val_feats = normalize(val_feats)\n",
    "    \n",
    "    val_distance = []\n",
    "    for feat in val_feats:\n",
    "        dis = np.dot(feat, val_feats.T)\n",
    "        val_distance.append(dis)\n",
    "        \n",
    "    best_threahold, best_f1 = 0, 0\n",
    "    for threahold in np.linspace(0.5, 0.99, 20):\n",
    "        val_submit = []\n",
    "        for dis in val_distance[:]:\n",
    "            pred = np.where(dis > threahold)[0]\n",
    "            if len(pred) == 1:\n",
    "                ids = dis.argsort()[::-1]\n",
    "                pred = [x for x in ids[dis[ids] > 0.1]][:2]\n",
    "\n",
    "            val_submit.append(pred)\n",
    "\n",
    "        val_f1s = []\n",
    "        for x, pred in zip(val_label, val_submit):\n",
    "            label = np.where(val_label == x)[0]\n",
    "            val_f1 = len(set(pred) & set(label)) / len(set(pred) | set(label)) \n",
    "            val_f1s.append(val_f1)\n",
    "\n",
    "        if best_f1 < np.mean(val_f1s):\n",
    "            best_f1 = np.mean(val_f1s)\n",
    "            best_threahold = threahold\n",
    "\n",
    "    print('Val', best_threahold, best_f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:30:49.982013Z",
     "iopub.status.busy": "2021-08-12T07:30:49.981425Z",
     "iopub.status.idle": "2021-08-12T07:30:50.000555Z",
     "shell.execute_reply": "2021-08-12T07:30:49.999403Z",
     "shell.execute_reply.started": "2021-08-12T07:30:49.981964Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_path = glob.glob('./电商图像检索_数据集/test/*')\n",
    "test_path.sort()\n",
    "test_path = np.array(test_path)\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(test_path, [0]*len(test_path),\n",
    "                        transforms.Compose([\n",
    "                        transforms.Resize((300, 300)),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ),\n",
    "    batch_size=50, shuffle=False, num_workers=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:30:51.716488Z",
     "iopub.status.busy": "2021-08-12T07:30:51.715890Z",
     "iopub.status.idle": "2021-08-12T07:31:06.333680Z",
     "shell.execute_reply": "2021-08-12T07:31:06.332992Z",
     "shell.execute_reply.started": "2021-08-12T07:30:51.716440Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.eval()\n",
    "test_feats = []\n",
    "with torch.no_grad():\n",
    "    for data in test_loader:\n",
    "        data = data[0].cuda()\n",
    "        feat = model(data)\n",
    "        test_feats.append(feat.data.cpu().numpy())\n",
    "        \n",
    "test_feats = np.vstack(test_feats)\n",
    "test_feats = normalize(test_feats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:31:50.805616Z",
     "iopub.status.busy": "2021-08-12T07:31:50.805042Z",
     "iopub.status.idle": "2021-08-12T07:31:53.987143Z",
     "shell.execute_reply": "2021-08-12T07:31:53.986348Z",
     "shell.execute_reply.started": "2021-08-12T07:31:50.805568Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_submit = []\n",
    "for path, feat in zip(test_path[:], test_feats[:]):\n",
    "    dis = np.dot(feat, test_feats.T)\n",
    "    pred = [x.split('/')[-1] for x in test_path[np.where(dis > 0.7)[0]]]\n",
    "    if len(pred) == 1:\n",
    "        ids = dis.argsort()[::-1]\n",
    "        pred = [x.split('/')[-1] for x in test_path[ids[:2]]]\n",
    "    \n",
    "    test_submit.append([\n",
    "        path.split('/')[-1],\n",
    "        pred\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:31:53.988996Z",
     "iopub.status.busy": "2021-08-12T07:31:53.988470Z",
     "iopub.status.idle": "2021-08-12T07:31:54.015422Z",
     "shell.execute_reply": "2021-08-12T07:31:54.014767Z",
     "shell.execute_reply.started": "2021-08-12T07:31:53.988963Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_submit = pd.DataFrame(test_submit, columns=['name', 'label'])\n",
    "test_submit['label'] = test_submit['label'].apply(lambda x: ' '.join(x))\n",
    "test_submit\n",
    "test_submit.to_csv('submit.csv',index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:34:00.668758Z",
     "iopub.status.busy": "2021-08-12T07:34:00.668187Z",
     "iopub.status.idle": "2021-08-12T07:34:10.987002Z",
     "shell.execute_reply": "2021-08-12T07:34:10.986238Z",
     "shell.execute_reply.started": "2021-08-12T07:34:00.668710Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 做了扩展查询\n",
    "test_submit = []\n",
    "for path, feat in zip(test_path[:], test_feats[:]):\n",
    "    dis = np.dot(feat, test_feats.T)\n",
    "\n",
    "    feat_qe = np.multiply(dis[np.argsort(dis)[::-1][:2]].reshape(2, -1),\n",
    "            test_feats[np.argsort(dis)[::-1][:2]]).mean(0)\n",
    "    dis = np.dot(feat_qe, test_feats.T)\n",
    "    \n",
    "    pred = [x.split('/')[-1] for x in test_path[np.where(dis > 0.7)[0]]]\n",
    "    if len(pred) <= 1:\n",
    "        ids = dis.argsort()[::-1]\n",
    "        pred = [x.split('/')[-1] for x in test_path[ids[:2]]]\n",
    "    \n",
    "    test_submit.append([\n",
    "        path.split('/')[-1],\n",
    "        pred\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T07:34:19.549065Z",
     "iopub.status.busy": "2021-08-12T07:34:19.548493Z",
     "iopub.status.idle": "2021-08-12T07:34:19.563492Z",
     "shell.execute_reply": "2021-08-12T07:34:19.562998Z",
     "shell.execute_reply.started": "2021-08-12T07:34:19.549017Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_submit = pd.DataFrame(test_submit, columns=['name', 'label'])\n",
    "test_submit['label'] = test_submit['label'].apply(lambda x: ' '.join(x))\n",
    "test_submit\n",
    "test_submit.to_csv('submit_qe.csv',index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
